# -*- coding: utf-8 -*-
"""
Created on Fri Feb 13 10:25:02 2026

@author: Michael Palace
"""

# -*- coding: utf-8 -*-

import pandas as pd
import networkx as nx
import matplotlib.pyplot as plt
import numpy as np
import matplotlib.cm as cm
from matplotlib.colors import ListedColormap

# -----------------------------
# Load data
# -----------------------------
# Crossing table; the columns used by this script are 'Female Tree',
# 'Male Tree', 'X', 'Y', 'X2' and 'Y2'.
df = pd.read_csv("crossing-300_208wks_run2.csv")

# -----------------------------
# Count repeated connections
# -----------------------------
# Collapse identical (Female Tree, Male Tree) rows into one row per pair;
# 'weight' holds the number of rows that shared that pair.
pair_sizes = df.groupby(['Female Tree', 'Male Tree']).size()
edge_counts = pair_sizes.reset_index(name='weight')

# -----------------------------
# Build weighted graph
# -----------------------------
# Undirected graph: (A, B) and (B, A) address the same edge.
G = nx.Graph()
pos = {}

# Columns are iterated directly instead of via DataFrame.iterrows():
# iterrows() packs every row into a single Series with one common dtype, so
# integer tree IDs / counts can be silently converted to floats.
for female, male, count in zip(
    edge_counts['Female Tree'],
    edge_counts['Male Tree'],
    edge_counts['weight'],
):
    if G.has_edge(female, male):
        # The reciprocal pair (male, female) was already added. add_edge()
        # would overwrite its weight, so accumulate the count instead.
        G[female][male]['weight'] += count
    else:
        G.add_edge(female, male, weight=count)

# Node positions: (X, Y) belongs to the female tree of a row and (X2, Y2) to
# the male tree. Later rows overwrite earlier ones for the same tree.
for female, male, x1, y1, x2, y2 in zip(
    df['Female Tree'], df['Male Tree'],
    df['X'], df['Y'], df['X2'], df['Y2'],
):
    pos[female] = (x1, y1)
    pos[male] = (x2, y2)

# -----------------------------
# Node sizes (degree)
# -----------------------------
# Map every node to its degree (number of incident edges).
degrees = dict(nx.degree(G))

def scale_node_size(deg, base=100, per_link=50):
    """Return the marker size used for a node with ``deg`` links.

    The size grows linearly with the degree: ``base + per_link * deg``.

    Parameters
    ----------
    deg : int
        Node degree (number of links).
    base : int or float, optional
        Size given to a node with zero links. Defaults to 100.
    per_link : int or float, optional
        Size added for every link. Defaults to 50.

    Returns
    -------
    int or float
        Marker size to pass as the ``node_size`` / ``s`` argument.
    """
    return base + per_link * deg

# One size per node, in the graph's node iteration order.
node_sizes = [scale_node_size(degrees[node]) for node in G]

# -----------------------------
# Edge weights → 5 bins
# -----------------------------
# One weight per edge, in G.edges() order.
weights = np.array([w for _, _, w in G.edges(data='weight')])

# Six evenly spaced boundaries delimit five equal-width bins spanning the
# observed weight range.
bins = np.linspace(weights.min(), weights.max(), 6)

# With right=True, np.digitize returns 0 for values equal to bins[0] (the
# minimum weight) and 1..5 for everything else. Clip so the minimum lands in
# the first bin; otherwise those edges get a sixth, thinner width class that
# matches neither their colour nor the five-entry legend.
bin_indices = np.clip(np.digitize(weights, bins, right=True), 1, 5)

# Discrete colormap: five RGBA rows sampled from plasma, one per bin.
cmap = cm.plasma(np.linspace(0.2, 0.9, 5))
edge_colors = [cmap[i - 1] for i in bin_indices]

# Scale edge width moderately: 1.5 for bin 1 up to 3.5 for bin 5.
edge_widths = 1 + 0.5 * bin_indices

# -----------------------------
# Plot
# -----------------------------
# Large canvas; constrained_layout makes room for decorations automatically.
fig, ax = plt.subplots(figsize=(70, 50), constrained_layout=True)

# Nodes are placed at their coordinates from `pos`; node size encodes degree,
# edge colour and width encode the binned connection count.
nx.draw(
    G, pos,
    ax=ax,
    with_labels=False,
    node_size=node_sizes,
    node_color='skyblue',
    edgecolors='black',
    linewidths=1.5,
    edge_color=edge_colors,
    width=edge_widths
)

# nx.draw() finishes by calling ax.set_axis_off(), which hides the axis
# labels set below. Turn the axes back on so the labels are actually drawn.
# NOTE(review): networkx may still hide the tick marks/labels by default;
# that is left untouched here.
ax.set_axis_on()

#ax.set_title("Tree Network on Landscape")
ax.set_xlabel("X coordinate")
ax.set_ylabel("Y coordinate")

# =====================================================
# NODE LEGEND (outside the axes, right side, upper middle)
# =====================================================
degree_values = degrees.values()
min_deg = min(degree_values)
max_deg = max(degree_values)

# Up to five evenly spaced degree values between the extremes. Rounding to
# integers can produce repeats, which np.unique removes (and sorts).
legend_degrees = np.unique(
    np.round(np.linspace(min_deg, max_deg, 5)).astype(int)
)

# Empty scatter artists act as legend proxies: they draw nothing on the plot
# but carry the marker size/colour for each legend entry.
node_handles = []
for d in legend_degrees:
    proxy = ax.scatter(
        [], [],
        s=scale_node_size(d),
        color='skyblue',
        edgecolors='black',
        label=f'{d} links',
    )
    node_handles.append(proxy)

node_legend_options = dict(
    title="Node Degree",
    loc='center left',
    bbox_to_anchor=(1.15, 0.70),   # farther right + upper middle
    frameon=True,
    fontsize=100,
    title_fontsize=100,
)
legend_nodes = ax.legend(handles=node_handles, **node_legend_options)

# Register the legend as a standalone artist so that a later ax.legend()
# call does not replace it.
ax.add_artist(legend_nodes)

# =====================================================
# EDGE LEGEND (outside the axes, right side, lower middle)
# =====================================================
# One "lower–upper" label per bin, built from consecutive bin boundaries
# rounded to whole numbers.
# NOTE(review): with a narrow weight range the rounded boundaries can repeat,
# and the legend line widths below (3..7) are not the plotted edge widths.
edge_labels = []
for lower_edge, upper_edge in zip(bins[:-1], bins[1:]):
    lower = int(np.round(lower_edge))
    upper = int(np.round(upper_edge))
    edge_labels.append(f"{lower}–{upper}")

# Proxy lines: one per bin, coloured with that bin's colormap row.
edge_handles = [
    plt.Line2D([0], [0], color=color, linewidth=3 + i, label=label)
    for i, (color, label) in enumerate(zip(cmap, edge_labels))
]

edge_legend_options = dict(
    title="Edge Weight\n(# Connections)",
    loc='center left',
    bbox_to_anchor=(1.15, 0.35),   # farther right + lower middle
    frameon=True,
    fontsize=100,
    title_fontsize=100,
)
legend_edges = ax.legend(handles=edge_handles, **edge_legend_options)

# -----------------------------
# Save + show
# -----------------------------
# Write the PNG before opening the interactive window.
# NOTE(review): a 70 x 50 inch figure at dpi=600 is a 42000 x 30000 pixel
# image; confirm that this size is intended.
fig.savefig("tree_network_weighted_binned.png", dpi=600)
plt.show()




# --------------------------------------------------
# BASIC DEGREE METRICS
# --------------------------------------------------

# Degree dictionary
#already calculated earlier in code
#degrees = dict(G.degree())

# Average degree
# Mean number of links per node, reusing the degree table built earlier.
avg_degree = np.mean(tuple(degrees.values()))
print("Average Degree:", avg_degree)

# --------------------------------------------------
# CLUSTERING COEFFICIENT
# --------------------------------------------------

# Unweighted: every edge counts the same.
avg_clustering = nx.average_clustering(G)
print("Average Clustering Coefficient:", avg_clustering)

# Weighted: the 'weight' edge attribute (connection count) is taken into
# account.
avg_clustering_weighted = nx.average_clustering(G, weight='weight')
print("Weighted Average Clustering Coefficient:", avg_clustering_weighted)

# --------------------------------------------------
# PATH LENGTH
# --------------------------------------------------

# Path length only works on connected components
fully_connected = nx.is_connected(G)

if fully_connected:
    # The whole graph is one component; measure it directly.
    path_graph = G
    path_label = "Average Shortest Path Length:"
else:
    # Fall back to the component with the most nodes.
    largest_cc = max(nx.connected_components(G), key=len)
    G_lcc = G.subgraph(largest_cc)
    path_graph = G_lcc
    path_label = "Average Shortest Path Length (Largest Component):"

avg_path_length = nx.average_shortest_path_length(path_graph)
print(path_label, avg_path_length)

if not fully_connected:
    print("Graph was not fully connected.")



